from celery import shared_task
from celery_client.app import app as celery_client
from django.utils import timezone

import os
import django

# Point Django at this project's settings (unless DJANGO_SETTINGS_MODULE is
# already set in the environment) and initialise Django. This has to run
# before the model imports below: Django models cannot be imported until the
# app registry has been populated by django.setup().
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'zmm.settings')
django.setup()

from load_planner.models import Task, TaskStatus
from common.models import Tenant
from django.db import connection, transaction

def retrieve_and_lock_task(id):
    """Fetch the ``Task`` whose primary key is ``id`` and lock its row.

    The query uses ``select_for_update()``, so any other transaction trying to
    update the same task blocks until the current transaction finishes. This
    keeps each Celery handler's read-modify-save of a task atomic. It must be
    evaluated inside a transaction; the ``with_tenant`` decorator wraps every
    handler in ``transaction.atomic()`` for this reason.

    Returns the matching ``Task`` instance, or ``None`` if no row matches.
    """
    locked_rows = Task.objects.select_for_update().filter(id=id)
    return locked_rows.first()

# Decorator: run a Celery task inside its tenant's DB schema and a transaction
from functools import wraps
def with_tenant(celery_task):
    """Wrap ``celery_task`` so it runs in the schema of the payload's tenant.

    The wrapped callable must take the payload dict as its first positional
    argument. ``payload['tenant_id']`` is used to look up the ``Tenant`` whose
    ``schema_name`` is then made the active schema on the shared DB connection.

    The task body runs inside ``transaction.atomic()``. Row locks taken via
    ``select_for_update()`` (see ``retrieve_and_lock_task``) are therefore held
    until the task's changes are committed or rolled back.

    The previously active schema is restored afterwards, even if the task or
    the commit raises. Without this, a failing task would leave the (reused)
    worker connection pointed at that tenant's schema.

    Raises whatever ``Tenant.objects.get`` raises for an unknown tenant, and
    propagates any exception from the task itself.
    """
    @wraps(celery_task)
    def decor(payload, *args, **kwargs):
        # Resolve the tenant and switch the connection to its schema.
        tenant = Tenant.objects.get(id=payload['tenant_id'])

        # NOTE(review): `connection.schemas` is assumed to yield the positional
        # arguments accepted by `set_schema` (it is unpacked below). Confirm
        # this against the tenant-aware DB backend in use.
        prev_schema = connection.schemas
        connection.set_schema(tenant.schema_name)
        try:
            # Run the actual Celery task body (e.g. on_task_submitted) atomically;
            # the transaction commits when this `with` block exits normally.
            with transaction.atomic():
                return celery_task(payload, *args, **kwargs)
        finally:
            # Always restore the schema that was active before this task ran.
            connection.set_schema(*prev_schema)
    return decor


@shared_task(name='internal.do_submit')
@with_tenant
def on_task_submitted(payload):
    """Hand a stored task off to the remote compute worker and mark it SUBMITTED.

    ``payload`` must contain ``task_id`` and ``tenant_id``. The task row is
    locked, and its ``config`` and ``input_dataset`` are forwarded to the remote
    ``zmm.do_compute`` Celery task. The row is then updated with the SUBMITTED
    status and the time of submission.
    """
    task = retrieve_and_lock_task(payload['task_id'])

    # Build the payload the compute-side Celery worker expects.
    compute_payload = dict(
        task_id=task.id,
        tenant_id=payload['tenant_id'],
        config=task.config,
        input_dataset=task.input_dataset,
    )

    # The compute-side handler is defined in a remote codebase, so it is
    # addressed by name through send_task rather than imported and called.
    celery_client.send_task('zmm.do_compute', args=[compute_payload])

    # Timestamp is taken after dispatch so it reflects the actual hand-off time.
    task.status = TaskStatus.SUBMITTED.value
    task.submitted_at = timezone.now()
    task.save()
    print('任务已提交')


@shared_task(name='zmm.do_start')
@with_tenant
def on_task_started(payload):
    """Handle the ``zmm.do_start`` Celery task: set the task's status to RUNNING.

    ``payload['task_id']`` identifies the task; its row is locked before the
    status change is saved.
    """
    # NOTE(review): the status is overwritten unconditionally. If zmm.do_start
    # can be delivered after zmm.do_finish / zmm.do_fail, a terminal status
    # would be regressed to RUNNING. Confirm message ordering guarantees.
    task_id = payload['task_id']
    running_task = retrieve_and_lock_task(task_id)
    running_task.status = TaskStatus.RUNNING.value
    running_task.save()
    print('任务运行中')


@shared_task(name='zmm.do_finish')
@with_tenant
def on_task_finished(payload):
    """Handle the ``zmm.do_finish`` Celery task: record a task's completion.

    Locks the task named by ``payload['task_id']``, marks it FINISHED, stamps
    ``finished_at`` with the current time and stores
    ``payload['result_dataset']`` on it before saving.
    """
    task_id = payload['task_id']
    done = retrieve_and_lock_task(task_id)

    done.status = TaskStatus.FINISHED.value
    done.finished_at = timezone.now()
    done.result_dataset = payload['result_dataset']

    done.save()
    print('任务已结束')


@shared_task(name='zmm.do_fail')
@with_tenant
def on_task_error(payload):
    """Handle the ``zmm.do_fail`` Celery task: set the task's status to ERROR.

    ``payload['task_id']`` identifies the task; its row is locked before the
    status change is saved.
    """
    task_id = payload['task_id']
    failed_task = retrieve_and_lock_task(task_id)
    failed_task.status = TaskStatus.ERROR.value
    failed_task.save()
    print('任务已失败')
